{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据集加载"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "\n",
    "MindSpore支持加载图像领域常用的数据集，用户可以直接使用`mindspore.dataset`中对应的类实现数据集的加载。目前支持的常用数据集及对应的数据集类如下表所示。\n",
    "\n",
    "|  图像数据集    | 数据集类  | 数据集简介 |\n",
    "|  :----       | :----    | :----    |\n",
    "| MNIST               | MnistDataset | MNIST是一个大型手写数字图像数据集，拥有60,000张训练图像和10,000张测试图像，常用于训练各种图像处理系统。 |\n",
    "| CIFAR-10          | Cifar10Dataset | CIFAR-10是一个微小图像数据集，包含10种类别下的60,000张32x32大小彩色图像，平均每种类别6,000张，其中5,000张为训练集，1,000张为测试集。 |\n",
    "| CIFAR-100       | Cifar100Dataset | CIFAR-100与CIFAR-10类似，但拥有100种类别，平均每种类别600张，其中500张为训练集，100张为测试集。 |\n",
    "| CelebA              | CelebADataset | CelebA是一个大型人脸图像数据集，包含超过200,000张名人人脸图像，每张图像拥有40个特征标记。 |\n",
    "| PASCAL-VOC  | VOCDataset | PASCAL-VOC是一个常用图像数据集，被广泛用于目标检测、图像分割等计算机视觉领域。 |\n",
    "| COCO                | CocoDataset | COCO是一个大型目标检测、图像分割、姿态估计数据集。 |\n",
    "| CLUE                | CLUEDataset | CLUE是一个大型中文语义理解数据集。 |\n",
    "\n",
    "MindSpore还支持加载多种数据存储格式下的数据集，用户可以直接使用`mindspore.dataset`中对应的类加载磁盘中的数据文件。目前支持的数据格式及对应加载方式如下表所示。\n",
    "\n",
    "|  数据格式    | 数据集类  | 数据格式简介 |\n",
    "|  :----                    | :----  | :----           |\n",
    "| MindRecord | MindDataset | MindRecord是MindSpore的自研数据格式，具有读写高效、易于分布式处理等优势。 |\n",
    "| Manifest | ManifestDataset | Manifest是华为ModelArts支持的一种数据格式，描述了原始文件和标注信息，可用于标注、训练、推理场景。 |\n",
    "| TFRecord | TFRecordDataset | TFRecord是TensorFlow定义的一种二进制数据文件格式。 |\n",
    "| NumPy | NumpySlicesDataset | NumPy数据源指的是已经读入内存中的NumPy arrays格式数据集。 |\n",
    "| Text File | TextFileDataset | Text File指的是常见的文本格式数据。 |\n",
    "| CSV File | CSVDataset | CSV指逗号分隔值，其文件以纯文本形式存储表格数据。 |\n",
    "\n",
    "MindSpore也同样支持使用`GeneratorDataset`自定义数据集的加载方式，用户可以根据需要实现自己的数据集类。\n",
    "\n",
    "> 更多详细的数据集加载接口说明，参见[API文档](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/mindspore.dataset.html)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 常用数据集加载\n",
    "\n",
    "下面将介绍几种常用数据集的加载方式。\n",
    "\n",
    "### CIFAR-10/100数据集\n",
    "\n",
    "下载CIFAR-10数据集并解压到指定位置："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 20:21:42--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 166235630 (159M) [application/zip]\n",
      "Saving to: ‘cifar10.zip’\n",
      "\n",
      "cifar10.zip         100%[===================>] 158.53M  70.0MB/s    in 2.3s    \n",
      "\n",
      "2020-12-09 20:21:44 (70.0 MB/s) - ‘cifar10.zip’ saved [166235630/166235630]\n",
      "\n",
      "Archive:  ./cifar10.zip\n",
      "   creating: ./datasets/cifar10/\n",
      "   creating: ./datasets/cifar10/test/\n",
      "  inflating: ./datasets/cifar10/test/test_batch.bin  \n",
      "   creating: ./datasets/cifar10/train/\n",
      "  inflating: ./datasets/cifar10/train/batches.meta.txt  \n",
      "  inflating: ./datasets/cifar10/train/data_batch_1.bin  \n",
      "  inflating: ./datasets/cifar10/train/data_batch_2.bin  \n",
      "  inflating: ./datasets/cifar10/train/data_batch_3.bin  \n",
      "  inflating: ./datasets/cifar10/train/data_batch_4.bin  \n",
      "  inflating: ./datasets/cifar10/train/data_batch_5.bin  \n",
      "./datasets/cifar10/\n",
      "├── test\n",
      "│   └── test_batch.bin\n",
      "└── train\n",
      "    ├── batches.meta.txt\n",
      "    ├── data_batch_1.bin\n",
      "    ├── data_batch_2.bin\n",
      "    ├── data_batch_3.bin\n",
      "    ├── data_batch_4.bin\n",
      "    └── data_batch_5.bin\n",
      "\n",
      "2 directories, 7 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10.zip\n",
    "!unzip -o ./cifar10.zip -d ./datasets/\n",
    "!tree ./datasets/cifar10/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的样例通过`Cifar10Dataset`接口加载CIFAR-10数据集，使用顺序采样器获取其中5个样本，然后展示了对应图片的形状和标签。\n",
    "\n",
    "CIFAR-100数据集和MNIST数据集的加载方式也与之类似。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 4\n",
      "Image shape: (32, 32, 3) , Label: 1\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar10/train/\"\n",
    "\n",
    "sampler = ds.SequentialSampler(num_samples=5)\n",
    "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### VOC数据集\n",
    "\n",
    "VOC数据集有多个版本，此处以VOC2012为例。下载[VOC2012数据集](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)并解压，目录结构如下。\n",
    "\n",
    "```\n",
    "└─ VOCtrainval_11-May-2012\n",
    "    └── VOCdevkit\n",
    "        └── VOC2012\n",
    "            ├── Annotations\n",
    "            ├── ImageSets\n",
    "            ├── JPEGImages\n",
    "            ├── SegmentationClass\n",
    "            └── SegmentationObject\n",
    "```\n",
    "\n",
    "下面的样例通过`VOCDataset`接口加载VOC2012数据集，分别演示了将任务指定为分割（Segmentation）和检测（Detection）时的原始图像形状和目标形状。\n",
    "\n",
    "```python\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_DIR = \"VOCtrainval_11-May-2012/VOCdevkit/VOC2012/\"\n",
    "\n",
    "dataset = ds.VOCDataset(DATA_DIR, task=\"Segmentation\", usage=\"train\", num_samples=2, decode=True, shuffle=False)\n",
    "\n",
    "print(\"[Segmentation]:\")\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"image shape:\", data[\"image\"].shape)\n",
    "    print(\"target shape:\", data[\"target\"].shape)\n",
    "\n",
    "dataset = ds.VOCDataset(DATA_DIR, task=\"Detection\", usage=\"train\", num_samples=1, decode=True, shuffle=False)\n",
    "\n",
    "print(\"[Detection]:\")\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"image shape:\", data[\"image\"].shape)\n",
    "    print(\"bbox shape:\", data[\"bbox\"].shape)\n",
    "```\n",
    "\n",
    "输出结果：\n",
    "\n",
    "```text\n",
    "[Segmentation]:\n",
    "image shape: (2268, 4032, 3)\n",
    "target shape: (680, 1209, 3)\n",
    "image shape: (2268, 4032, 3)\n",
    "target shape: (680, 1209, 3)\n",
    "[Detection]:\n",
    "image shape: (2268, 4032, 3)\n",
    "bbox shape: (3, 4)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### COCO数据集\n",
    "\n",
    "COCO数据集有多个版本，此处以COCO2017的验证数据集为例。下载COCO2017的[验证集](http://images.cocodataset.org/zips/val2017.zip)、[检测任务标注](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)和[全景分割任务标注](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip)并解压，只取其中的验证集部分，按以下目录结构存放。\n",
    "\n",
    "```\n",
    "└─ COCO\n",
    "    ├── val2017\n",
    "    └── annotations\n",
    "        ├── instances_val2017.json\n",
    "        ├── panoptic_val2017.json\n",
    "        └── person_keypoints_val2017.json\n",
    "```\n",
    "\n",
    "下面的样例通过`CocoDataset`接口加载COCO2017数据集，分别演示了将任务指定为目标检测（Detection）、背景分割（Stuff）、关键点检测（Keypoint）和全景分割（Panoptic）时获取到的不同数据。\n",
    "\n",
    "```python\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_DIR = \"COCO/val2017/\"\n",
    "ANNOTATION_FILE = \"COCO/annotations/instances_val2017.json\"\n",
    "KEYPOINT_FILE = \"COCO/annotations/person_keypoints_val2017.json\"\n",
    "PANOPTIC_FILE = \"COCO/annotations/panoptic_val2017.json\"\n",
    "\n",
    "dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task=\"Detection\", num_samples=1)\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Detection:\", data.keys())\n",
    "\n",
    "dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task=\"Stuff\", num_samples=1)\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Stuff:\", data.keys())\n",
    "\n",
    "dataset = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task=\"Keypoint\", num_samples=1)\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Keypoint:\", data.keys())\n",
    "\n",
    "dataset = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task=\"Panoptic\", num_samples=1)\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Panoptic:\", data.keys())\n",
    "```\n",
    "\n",
    "输出结果：\n",
    "\n",
    "```text\n",
    "Detection: dict_keys(['image', 'bbox', 'category_id', 'iscrowd'])\n",
    "Stuff: dict_keys(['image', 'segmentation', 'iscrowd'])\n",
    "Keypoint: dict_keys(['image', 'keypoints', 'num_keypoints'])\n",
    "Panoptic: dict_keys(['image', 'bbox', 'category_id', 'iscrowd', 'area'])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 特定格式数据集加载\n",
    "\n",
    "下面将介绍几种特定格式数据集文件的加载方式。\n",
    "\n",
    "### MindRecord数据格式\n",
    "\n",
    "MindRecord是MindSpore定义的一种数据格式，使用MindRecord能够获得更好的性能提升。\n",
    "\n",
    "> 阅读[数据格式转换](https://www.mindspore.cn/doc/programming_guide/zh-CN/master/dataset_conversion.html)章节，了解如何将数据集转化为MindSpore数据格式。\n",
    "\n",
    "执行本例之前需下载对应的测试数据`test_mindrecord.zip`并解压到指定位置，执行如下命令："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 20:21:48--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_mindrecord.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 60583 (59K) [application/zip]\n",
      "Saving to: ‘test_mindrecord.zip’\n",
      "\n",
      "test_mindrecord.zip 100%[===================>]  59.16K  --.-KB/s    in 0s      \n",
      "\n",
      "2020-12-09 20:21:48 (331 MB/s) - ‘test_mindrecord.zip’ saved [60583/60583]\n",
      "\n",
      "Archive:  ./test_mindrecord.zip\n",
      "  inflating: ./datasets/mindspore_dataset_loading/test.mindrecord.db  \n",
      "  inflating: ./datasets/mindspore_dataset_loading/test.mindrecord  \n",
      "./datasets/mindspore_dataset_loading/\n",
      "├── test.mindrecord\n",
      "└── test.mindrecord.db\n",
      "\n",
      "0 directories, 2 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_mindrecord.zip\n",
    "!unzip -o ./test_mindrecord.zip -d ./datasets/mindspore_dataset_loading/\n",
    "!tree ./datasets/mindspore_dataset_loading/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的样例通过`MindDataset`接口加载MindRecord文件，并展示已加载数据的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['chinese', 'english'])\n",
      "dict_keys(['chinese', 'english'])\n",
      "dict_keys(['chinese', 'english'])\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_FILE = [\"./datasets/mindspore_dataset_loading/test.mindrecord\"]\n",
    "mindrecord_dataset = ds.MindDataset(DATA_FILE)\n",
    "\n",
    "for data in mindrecord_dataset.create_dict_iterator(output_numpy=True):\n",
    "    print(data.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Manifest数据格式\n",
    "\n",
    "Manifest是华为ModelArts支持的数据格式文件，详细说明请参见[Manifest文档](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0009.html)。\n",
    "\n",
    "本次示例需下载测试数据`test_manifest.zip`并将其解压到指定位置，执行如下命令："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 20:21:49--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_manifest.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 440877 (431K) [application/zip]\n",
      "Saving to: ‘test_manifest.zip’\n",
      "\n",
      "test_manifest.zip   100%[===================>] 430.54K  --.-KB/s    in 0.005s  \n",
      "\n",
      "2020-12-09 20:21:49 (89.2 MB/s) - ‘test_manifest.zip’ saved [440877/440877]\n",
      "\n",
      "Archive:  ./test_manifest.zip\n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_manifest/eval/1.JPEG  \n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_manifest/eval/2.JPEG  \n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_manifest/test_manifest.json  \n",
      "   creating: ./datasets/mindspore_dataset_loading/test_manifest/train/\n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_manifest/train/1.JPEG  \n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_manifest/train/2.JPEG  \n",
      "./datasets/mindspore_dataset_loading/test_manifest/\n",
      "├── eval\n",
      "│   ├── 1.JPEG\n",
      "│   └── 2.JPEG\n",
      "├── test_manifest.json\n",
      "└── train\n",
      "    ├── 1.JPEG\n",
      "    └── 2.JPEG\n",
      "\n",
      "2 directories, 5 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_manifest.zip\n",
    "!unzip -o ./test_manifest.zip -d ./datasets/mindspore_dataset_loading/test_manifest/\n",
    "!tree ./datasets/mindspore_dataset_loading/test_manifest/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的样例通过`ManifestDataset`接口加载Manifest文件`test_manifest.json`，并展示已加载数据的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "import os\n",
    "\n",
    "DATA_FILE = \"./datasets/mindspore_dataset_loading/test_manifest/test_manifest.json\"\n",
    "manifest_dataset = ds.ManifestDataset(DATA_FILE)\n",
    "\n",
    "for data in manifest_dataset.create_dict_iterator():\n",
    "    print(data[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TFRecord数据格式\n",
    "\n",
    "TFRecord是TensorFlow定义的一种二进制数据文件格式。\n",
    "\n",
    "下面的样例通过`TFRecordDataset`接口加载TFRecord文件，并介绍了两种不同的数据集格式设定方案。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下载`tfrecord`测试数据`test_tftext.zip`并解压到指定位置，执行如下命令："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 20:21:50--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_tftext.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 522 [application/zip]\n",
      "Saving to: ‘test_tftext.zip’\n",
      "\n",
      "test_tftext.zip     100%[===================>]     522  --.-KB/s    in 0s      \n",
      "\n",
      "2020-12-09 20:21:50 (31.9 MB/s) - ‘test_tftext.zip’ saved [522/522]\n",
      "\n",
      "Archive:  ./test_tftext.zip\n",
      "  inflating: ./datasets/mindspore_dataset_loading/test_tfrecord/test_tftext.tfrecord  \n",
      "./datasets/mindspore_dataset_loading/test_tfrecord/\n",
      "└── test_tftext.tfrecord\n",
      "\n",
      "0 directories, 1 file\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_tftext.zip\n",
    "!unzip -o ./test_tftext.zip -d ./datasets/mindspore_dataset_loading/test_tfrecord/\n",
    "!tree ./datasets/mindspore_dataset_loading/test_tfrecord/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 传入数据集路径或TFRecord文件列表，本例使用`test_tftext.tfrecord`，创建`TFRecordDataset`对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['chinese', 'line', 'words'])\n",
      "dict_keys(['chinese', 'line', 'words'])\n",
      "dict_keys(['chinese', 'line', 'words'])\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_FILE = \"./datasets/mindspore_dataset_loading/test_tfrecord/test_tftext.tfrecord\"\n",
    "tfrecord_dataset = ds.TFRecordDataset(DATA_FILE)\n",
    "\n",
    "for data in tfrecord_dataset.create_dict_iterator():\n",
    "    print(data.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 用户可以通过编写Schema文件或创建Schema对象，设定数据集格式及特征。\n",
    "\n",
    "    - 编写Schema文件\n",
    "\n",
    "        将数据集格式和特征按JSON格式写入Schema文件，示例如下：\n",
    "\n",
    "        ```json\n",
    "        {\n",
    "         \"columns\": {\n",
    "             \"image\": {\n",
    "                 \"type\": \"uint8\",\n",
    "                 \"rank\": 1\n",
    "                 },\n",
    "             \"label\" : {\n",
    "                 \"type\": \"string\",\n",
    "                 \"rank\": 1\n",
    "                 }\n",
    "             \"id\" : {\n",
    "                 \"type\": \"int64\",\n",
    "                 \"rank\": 0\n",
    "                 }\n",
    "             }\n",
    "         }\n",
    "        ```\n",
    "\n",
    "        - `columns`：列信息字段，需要根据数据集的实际列名定义。上面的示例中，数据集列为`image`、`label`和`id`。\n",
    "\n",
    "        然后在创建`TFRecordDataset`时将Schema文件路径传入。\n",
    "\n",
    "        ```python\n",
    "        SCHEMA_DIR = \"dataset_schema_path/schema.json\"\n",
    "        tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=SCHEMA_DIR)\n",
    "        ```\n",
    "\n",
    "    - 创建Schema对象\n",
    "\n",
    "        创建Schema对象，为其添加自定义字段，然后在创建数据集对象时传入。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import dtype as mstype\n",
    "schema = ds.Schema()\n",
    "schema.add_column('image', de_type=mstype.uint8)\n",
    "schema.add_column('label', de_type=mstype.int32)\n",
    "tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NumPy数据格式\n",
    "\n",
    "如果所有数据已经读入内存，可以直接使用`NumpySlicesDataset`类将其加载。\n",
    "\n",
    "下面的样例分别介绍了通过`NumpySlicesDataset`加载arrays数据、 list数据和dict数据的方式。\n",
    "\n",
    "- 加载NumPy arrays数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.89286015 0.33197981] [0.33540785]\n",
      "[0.82122912 0.04169663] [0.62251943]\n",
      "[0.10765668 0.59505206] [0.43814143]\n",
      "[0.52981736 0.41880743] [0.73588211]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "np.random.seed(6)\n",
    "features, labels = np.random.sample((4, 2)), np.random.sample((4, 1))\n",
    "\n",
    "data = (features, labels)\n",
    "dataset = ds.NumpySlicesDataset(data, column_names=[\"col1\", \"col2\"], shuffle=False)\n",
    "\n",
    "for data in dataset:\n",
    "    print(data[0], data[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 加载Python list数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2]\n",
      "[3 4]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "data1 = [[1, 2], [3, 4]]\n",
    "\n",
    "dataset = ds.NumpySlicesDataset(data1, column_names=[\"col1\"], shuffle=False)\n",
    "\n",
    "for data in dataset:\n",
    "    print(data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 加载Python dict数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'col1': Tensor(shape=[], dtype=Int64, value= 1), 'col2': Tensor(shape=[], dtype=Int64, value= 3)}\n",
      "{'col1': Tensor(shape=[], dtype=Int64, value= 2), 'col2': Tensor(shape=[], dtype=Int64, value= 4)}\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "data1 = {\"a\": [1, 2], \"b\": [3, 4]}\n",
    "\n",
    "dataset = ds.NumpySlicesDataset(data1, column_names=[\"col1\", \"col2\"], shuffle=False)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CSV数据格式\n",
    "\n",
    "下面的样例通过`CSVDataset`加载CSV格式数据集文件，并展示了已加载数据的`keys`。\n",
    "\n",
    "下载测试数据`test_csv.zip`并解压到指定位置，执行如下命令："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 20:21:50--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_csv.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 344 [application/zip]\n",
      "Saving to: ‘test_csv.zip’\n",
      "\n",
      "test_csv.zip        100%[===================>]     344  --.-KB/s    in 0s      \n",
      "\n",
      "2020-12-09 20:21:50 (19.3 MB/s) - ‘test_csv.zip’ saved [344/344]\n",
      "\n",
      "Archive:  ./test_csv.zip\n",
      " extracting: ./datasets/mindspore_dataset_loading/test_csv/test2.csv  \n",
      " extracting: ./datasets/mindspore_dataset_loading/test_csv/test1.csv  \n",
      "./datasets/mindspore_dataset_loading/test_csv/\n",
      "├── test1.csv\n",
      "└── test2.csv\n",
      "\n",
      "0 directories, 2 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/test_csv.zip\n",
    "!unzip -o ./test_csv.zip -d ./datasets/mindspore_dataset_loading/test_csv/\n",
    "!tree ./datasets/mindspore_dataset_loading/test_csv/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "传入数据集路径或csv文件列表，Text格式数据集文件的加载方式与CSV文件类似。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['a', 'b', 'c', 'd'])\n",
      "dict_keys(['a', 'b', 'c', 'd'])\n",
      "dict_keys(['a', 'b', 'c', 'd'])\n",
      "dict_keys(['a', 'b', 'c', 'd'])\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "DATA_FILE = [\"./datasets/mindspore_dataset_loading/test_csv/test1.csv\",\"./datasets/mindspore_dataset_loading/test_csv/test2.csv\"]\n",
    "csv_dataset = ds.CSVDataset(DATA_FILE)\n",
    "\n",
    "for data in csv_dataset.create_dict_iterator(output_numpy=True):\n",
    "    print(data.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义数据集加载\n",
    "\n",
    "对于目前MindSpore不支持直接加载的数据集，可以通过构造`GeneratorDataset`对象实现自定义方式的加载，或者将其转换成MindRecord数据格式。下面分别展示几种不同的自定义数据集加载方法，为了便于对比，生成的随机数据保持相同。\n",
    "\n",
    "### 构造数据集生成函数\n",
    "\n",
    "构造生成函数定义数据返回方式，再使用此函数构建自定义数据集对象。此方法适用于简单场景。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.36510558 0.45120592] [0.78888122]\n",
      "[0.49606035 0.07562207] [0.38068183]\n",
      "[0.57176158 0.28963401] [0.16271622]\n",
      "[0.30880446 0.37487617] [0.54738768]\n",
      "[0.81585667 0.96883469] [0.77994068]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "np.random.seed(58)\n",
    "data = np.random.sample((5, 2))\n",
    "label = np.random.sample((5, 1))\n",
    "\n",
    "def GeneratorFunc():\n",
    "    for i in range(5):\n",
    "        yield (data[i], label[i])\n",
    "\n",
    "dataset = ds.GeneratorDataset(GeneratorFunc, [\"data\", \"label\"])\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data[\"data\"], data[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造可迭代的数据集类\n",
    "\n",
    "构造数据集类实现`__iter__`和`__next__`方法，再使用此类的对象构建自定义数据集对象。相比于直接定义生成函数，使用数据集类能够实现更多的自定义功能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.36510558 0.45120592] [0.78888122]\n",
      "[0.49606035 0.07562207] [0.38068183]\n",
      "[0.57176158 0.28963401] [0.16271622]\n",
      "[0.30880446 0.37487617] [0.54738768]\n",
      "[0.81585667 0.96883469] [0.77994068]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "class IterDatasetGenerator:\n",
    "    def __init__(self):\n",
    "        np.random.seed(58)\n",
    "        self.__index = 0\n",
    "        self.__data = np.random.sample((5, 2))\n",
    "        self.__label = np.random.sample((5, 1))\n",
    "\n",
    "    def __next__(self):\n",
    "        if self.__index >= len(self.__data):\n",
    "            raise StopIteration\n",
    "        else:\n",
    "            item = (self.__data[self.__index], self.__label[self.__index])\n",
    "            self.__index += 1\n",
    "            return item\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.__data)\n",
    "\n",
    "dataset_generator = IterDatasetGenerator()\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data[\"data\"], data[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造可随机访问的数据集类\n",
    "\n",
    "构造数据集类实现`__getitem__`方法，再使用此类的对象构建自定义数据集对象。此方法可以用于实现分布式训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.36510558 0.45120592] [0.78888122]\n",
      "[0.49606035 0.07562207] [0.38068183]\n",
      "[0.57176158 0.28963401] [0.16271622]\n",
      "[0.30880446 0.37487617] [0.54738768]\n",
      "[0.81585667 0.96883469] [0.77994068]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "class GetDatasetGenerator:\n",
    "    def __init__(self):\n",
    "        np.random.seed(58)\n",
    "        self.__data = np.random.sample((5, 2))\n",
    "        self.__label = np.random.sample((5, 1))\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return (self.__data[index], self.__label[index])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.__data)\n",
    "\n",
    "dataset_generator = GetDatasetGenerator()\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data[\"data\"], data[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果用户希望实现分布式训练，则需要在此方式的基础上，在采样器类中实现`__iter__`方法，每次返回采样数据的索引。需要补充的代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.36510558 0.45120592] [0.78888122]\n",
      "[0.57176158 0.28963401] [0.16271622]\n",
      "[0.81585667 0.96883469] [0.77994068]\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "class MySampler():\n",
    "    def __init__(self, dataset, local_rank, world_size):\n",
    "        self.__num_data = len(dataset)\n",
    "        self.__local_rank = local_rank\n",
    "        self.__world_size = world_size\n",
    "        self.samples_per_rank = int(math.ceil(self.__num_data / float(self.__world_size)))\n",
    "        self.total_num_samples = self.samples_per_rank * self.__world_size\n",
    "\n",
    "    def __iter__(self):\n",
    "        indices = list(range(self.__num_data))\n",
    "        indices.extend(indices[:self.total_num_samples-len(indices)])\n",
    "        indices = indices[self.__local_rank:self.total_num_samples:self.__world_size]\n",
    "        return iter(indices)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.samples_per_rank\n",
    "\n",
    "dataset_generator = GetDatasetGenerator()\n",
    "sampler = MySampler(dataset_generator, local_rank=0, world_size=2)\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False, sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data[\"data\"], data[\"label\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.0.1",
   "language": "python",
   "name": "mindspore-1.0.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
